import os
import sys
# Put the parent of the *current working directory* on sys.path (only if it
# is not already there) so the project packages imported below (`models`,
# `data`) resolve. os.path.abspath('') is the CWD, not this file's directory,
# so this only works when the script/notebook is launched from the expected
# directory — NOTE(review): presumably a subdirectory next to `models/` and
# `data/`; confirm.
dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)
import json
import numpy as np
from models.model_utils import get_model
from data.data_utils import get_dataloader
import torch
import matplotlib.pyplot as plt
from torch import nn 
import torch.nn.functional as F
# Progress marker: the imports above (torch, project packages) can be slow.
print('finish import')


def all_feature_mean(rep):
    """Return the mean of every feature channel of ``rep``.

    Channel ``c`` is the slice ``rep[:, c]``. That slice is reshaped to
    ``(rep.shape[0], -1)`` and averaged over all of its entries, i.e. over
    samples and any trailing axes together.

    Args:
        rep: numpy array whose axis 1 indexes features, e.g. ``(N, C)`` or
            ``(N, C, H, W)``.

    Returns:
        1-D numpy array of length ``rep.shape[1]``, one mean per channel.
    """
    num_samples = rep.shape[0]
    channel_means = [
        rep[:, channel].reshape((num_samples, -1)).mean()
        for channel in range(rep.shape[1])
    ]
    return np.array(channel_means)


def all_feature_std(rep):
    """Return the standard deviation of every feature channel of ``rep``.

    For each index ``k`` along axis 1, ``rep[:, k]`` is flattened to
    ``(rep.shape[0], -1)``. Its numpy ``std()`` (population, ddof=0) is taken
    over all of those entries.

    Args:
        rep: numpy array whose axis 1 indexes features.

    Returns:
        1-D numpy array of length ``rep.shape[1]``, one std per channel.
    """
    return np.array(
        [rep[:, k].reshape((rep.shape[0], -1)).std() for k in range(rep.shape[1])]
    )


def get_all_repr(model, data_loader, layer_name, device=None):
    """Collect a layer's full activations over every batch of ``data_loader``.

    Args:
        model: object exposing ``get_intermediate(x)``, which returns a mapping
            from layer name to activation tensor.
        data_loader: iterable of ``(inputs, targets)`` batches; targets are
            ignored.
        layer_name: key into the mapping returned by ``get_intermediate``.
        device: device the inputs are moved to. ``None`` (default) uses the
            device of the model's first parameter. It falls back to CUDA, the
            previously hard-coded behaviour, when that cannot be determined.

    Returns:
        numpy array of all batches' activations concatenated along axis 0.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        params = getattr(model, 'parameters', None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = torch.device('cuda') if first_param is None else first_param.device
    representations = []
    # Inference only: no_grad stops autograd from recording a graph per batch.
    with torch.no_grad():
        for x, _ in data_loader:
            rep = model.get_intermediate(x.to(device))[layer_name]
            representations.append(rep.detach().cpu().numpy())
    if not representations:
        raise ValueError('data_loader yielded no batches')
    return np.concatenate(representations, axis=0)


def get_representation(model, dataloader, layer_name, neuron_index, device=None):
    """Collect one channel/neuron of a layer's activations over a dataloader.

    Args:
        model: object exposing ``get_intermediate(x)``, which returns a mapping
            from layer name to activation tensor.
        dataloader: iterable of ``(inputs, targets)`` batches; targets are
            ignored.
        layer_name: key into the mapping returned by ``get_intermediate``.
        neuron_index: index into axis 1 of that activation.
        device: device the inputs are moved to. ``None`` (default) uses the
            device of the model's first parameter. It falls back to CUDA, the
            previously hard-coded behaviour, when that cannot be determined.

    Returns:
        numpy array of ``activation[:, neuron_index]`` for every batch,
        concatenated along axis 0.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        params = getattr(model, 'parameters', None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = torch.device('cuda') if first_param is None else first_param.device
    representations = []
    # Inference only: no_grad stops autograd from recording a graph per batch.
    with torch.no_grad():
        for x, _ in dataloader:
            rep = model.get_intermediate(x.to(device))[layer_name]
            representations.append(rep[:, neuron_index].detach().cpu().numpy())
    if not representations:
        raise ValueError('dataloader yielded no batches')
    return np.concatenate(representations, axis=0)


def _first_param_device(model):
    """Return the device of ``model``'s first parameter, or CUDA if none is found."""
    params = getattr(model, 'parameters', None)
    first_param = next(iter(params()), None) if callable(params) else None
    return torch.device('cuda') if first_param is None else first_param.device


def compute_statistics(model_1, model_2, dataloader, layer_name, neuron_index_1, neuron_index_2, verbose=False):
    """Compare one channel of ``layer_name`` between two models on the same data.

    For every batch, ``get_intermediate(x)[layer_name][:, neuron_index]`` is
    taken from each model. The results are concatenated and flattened to
    ``(num_samples, -1)``, then compared element-wise.

    Args:
        model_1, model_2: objects exposing ``get_intermediate(x)``, which
            returns a mapping from layer name to activation tensor.
        dataloader: iterable of ``(inputs, targets)`` batches; targets are
            ignored.
        layer_name: key into the mapping returned by ``get_intermediate``.
        neuron_index_1, neuron_index_2: index along axis 1 of each model's
            activation.
        verbose: if True, print progress and the first flattened sample of
            each representation.

    Returns:
        dict with keys:
            'model_1_mean', 'model_2_mean': mean over all activation entries.
            'model_1_std', 'model_2_std': population (ddof=0) standard
                deviation over all entries.
            'unnormalized_correlation': mean of the product of the centred
                activations (the covariance).
            'correlation': covariance / (std_1 * std_2), i.e. the Pearson
                correlation over all entries; NaN if either std is zero.

    Raises:
        ValueError: if ``dataloader`` yields no batches, or if the two
            flattened representations differ in shape (element-wise products
            would otherwise broadcast silently and give a wrong result).
    """
    if verbose:
        print('Computing representation.')
    device_1 = _first_param_device(model_1)
    device_2 = _first_param_device(model_2)
    chunks_1, chunks_2 = [], []
    # Single pass: each batch goes through both models, so the element-wise
    # comparison stays aligned even if the loader shuffles or applies random
    # transforms (two separate passes would not guarantee that).
    with torch.no_grad():
        for x, _ in dataloader:
            rep_1 = model_1.get_intermediate(x.to(device_1))[layer_name]
            rep_2 = model_2.get_intermediate(x.to(device_2))[layer_name]
            chunks_1.append(rep_1[:, neuron_index_1].detach().cpu().numpy())
            chunks_2.append(rep_2[:, neuron_index_2].detach().cpu().numpy())
    if not chunks_1:
        raise ValueError('dataloader yielded no batches')
    if verbose:
        print('Finished computing representation.')
    model_1_repr = np.concatenate(chunks_1, axis=0)
    model_2_repr = np.concatenate(chunks_2, axis=0)
    # Flatten everything after the sample axis.
    model_1_repr = model_1_repr.reshape((model_1_repr.shape[0], -1))
    model_2_repr = model_2_repr.reshape((model_2_repr.shape[0], -1))
    if verbose:
        print(model_1_repr[0])
        print(model_2_repr[0])
    if model_1_repr.shape != model_2_repr.shape:
        raise ValueError(
            f'Representation shapes differ: {model_1_repr.shape} vs {model_2_repr.shape}'
        )
    model_1_mean = model_1_repr.mean()
    model_2_mean = model_2_repr.mean()
    model_1_std = model_1_repr.std()
    model_2_std = model_2_repr.std()
    unnormalized_correlation = (
        (model_1_repr - model_1_mean) * (model_2_repr - model_2_mean)
    ).mean()
    denominator = model_1_std * model_2_std
    # A constant representation has zero std; report NaN explicitly instead
    # of triggering a divide-by-zero warning.
    correlation = unnormalized_correlation / denominator if denominator != 0 else np.nan
    return {
        'model_1_mean': model_1_mean,
        'model_2_mean': model_2_mean,
        'model_1_std': model_1_std,
        'model_2_std': model_2_std,
        'correlation': correlation,
        'unnormalized_correlation': unnormalized_correlation
    }


if __name__ == '__main__':
    print(f'Cuda {torch.cuda.is_available()}')
    # Format string for the checkpoint directory; '{}' is replaced by exp_name.
    root = os.environ.get('CHECKPOINT_ROOT')
    if root is None:
        sys.exit("Set CHECKPOINT_ROOT to the checkpoint root format string, e.g. '/checkpoints/{}'.")
    exp_name = 'Resnet18_CIFAR10_10k_100_COPIES'
    exp_dir = root.format(exp_name)
    model_1_path = os.path.join(exp_dir, 'model_17')
    model_2_path = os.path.join(exp_dir, 'model_11')
    # Both models are built from model_1's config (assumed to share the architecture).
    config_path = os.path.join(model_1_path, 'config')
    weight_1_path = os.path.join(model_1_path, 'weight.pt')
    weight_2_path = os.path.join(model_2_path, 'weight.pt')
    # 'gpu' is not a valid torch device type; CUDA devices are named 'cuda'.
    device = torch.device('cuda')
    print('Loading config...')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print(config)
    model_1 = get_model(config).cuda()
    print('Model 1 created')
    print('Reading weights from disk.')
    state_dict = torch.load(weight_1_path, map_location=device)
    model_1.load_state_dict(state_dict)
    # eval() switches layers such as BatchNorm/Dropout to inference behaviour.
    # Without it, BatchNorm (if the model has any) normalises with per-batch
    # statistics and keeps updating its running stats while we measure.
    model_1.eval()
    print('Model 1 loaded...')
    model_2 = get_model(config).cuda()
    state_dict = torch.load(weight_2_path, map_location=device)
    model_2.load_state_dict(state_dict)
    model_2.eval()
    print('All models loaded!')
    print('getting data loader')
    all_dataloaders = get_dataloader(config)
    print('get data loader')
    _, test_loader = all_dataloaders[0]
    stats = compute_statistics(
        model_1=model_1,
        model_2=model_2,
        dataloader=test_loader,
        layer_name='layer4',
        neuron_index_1=0,
        neuron_index_2=0
    )
    print(stats)
